Skip to content

[weather] Add Fgn model #1660

Open
kashif wants to merge 56 commits into
NVIDIA:mainfrom
kashif:fgn
Open

[weather] Add Fgn model #1660
kashif wants to merge 56 commits into
NVIDIA:mainfrom
kashif:fgn

Conversation

@kashif

@kashif kashif commented May 21, 2026

Copy link
Copy Markdown

PhysicsNeMo Pull Request

Description

Adds the FGN (Functional Generative Networks) weather model example from arXiv:2506.10772.

FGN is a latent-conditioned UNet trained with fair-CRPS to generate calibrated ensemble forecasts. This PR includes:

  • examples/weather/fgn/ — training, evaluation, and inference scripts
  • Fair-CRPS loss, autoregressive rollout up to 8 steps (paper Table A.2 schedule), and deep-ensemble inference (§2.2.1)
  • Per-variable CRPS / RMSE / spread-skill / rank histogram / 1D power spectra diagnostics matching paper Figures 2–3
  • Hydra config with AdamW (lr=8e-5, weight_decay=0.1) + cosine LR schedule with warmup per Table A.2
  • LaunchLogger / W&B integration following the PhysicsNeMo convention

Checklist

Dependencies

No new dependencies beyond what is already in the PhysicsNeMo environment.

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

@copy-pr-bot

copy-pr-bot Bot commented May 21, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented May 21, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds a new examples/weather/fgn/ example implementing the Functional Generative Network (FGN) weather model scaffold in PhysicsNeMo, covering training, inference, datasets, loss/metrics utilities, and SLURM scripts. Two P1 bugs need attention before merging.

  • Inference rollout bug: _rollout in inference.py uses rollout_history[:, 1] which hardcodes history_frames=2; any other history size causes a shape mismatch runtime error on the second step.
  • AMP silently ignored: TrainingConfig.amp is declared and the production shell script passes training.amp=true, but the trainer never reads this flag — no autocast or GradScaler is applied.
  • Personal paths in SLURM scripts: /mnt/home/kashif/ and /mnt/data/kashif/ are hardcoded in scripts/train_2024_val.sh, compute_stats_2024.sh, and prefetch_arco_2024.sh, making them unusable for any other user.

Important Files Changed

Filename Overview
examples/weather/fgn/inference.py Adds autoregressive inference with deep-ensemble support; contains a P1 bug in _rollout where rollout_history[:, 1] hardcodes history_frames=2, breaking at runtime for any other history window size.
examples/weather/fgn/utils/trainer.py Full training loop with AR rollout, validation metrics, and checkpoint management; AMP config field declared and used in production scripts but never implemented in the training loop.
examples/weather/fgn/utils/loss.py Implements fair CRPS (eq. 4-5 of the FGN paper) and GraphCast-style channel weights with correct geopotential halving; logic appears sound and well-tested.
examples/weather/fgn/utils/nn.py Latent-conditioned U-Net (FGNUNet) scaffold with conditional residual blocks; straightforward implementation, no issues found.
examples/weather/fgn/utils/parallel.py FSDP + ShardTensor data/domain parallelism helper adapted from StormCast; sharded dataloader and nested scatter logic look correct.
examples/weather/fgn/datasets/arco.py ERA5/ARCO dataset wrapper with tp accumulation, SST NaN imputation, and z-score normalization; num_samples formula is correct, SST impute logic is sound.
examples/weather/fgn/utils/metrics.py Validation diagnostics (CRPS, RMSE, spread-skill, rank histograms, power spectra, derived variables); logic appears correct, well-documented limitations noted.
examples/weather/fgn/scripts/train_2024_val.sh Production SLURM training script with personal home-directory paths hardcoded (/mnt/home/kashif/, /mnt/data/kashif/) that will fail for any other user without modification.
examples/weather/fgn/utils/config.py Pydantic dataclass configs for training/inference/model; amp: bool = False declared but the flag is never consumed by the trainer.
examples/weather/fgn/datasets/mock.py Synthetic smoke-test dataset with deterministic wave fields; no issues found.

Reviews (1): Last reviewed commit: "Merge branch 'main' into fgn" | Re-trigger Greptile

Comment thread examples/weather/fgn/inference.py Outdated
Comment thread examples/weather/fgn/utils/trainer.py
Comment thread examples/weather/fgn/scripts/train_2024_val.sh Outdated
Comment thread examples/weather/fgn/utils/trainer.py Outdated
Comment on lines +319 to +330
for k in range(ar_steps):
members = []
for n in range(num_samples):
hist_n = per_member_hist[:, n]
pred_n = self._step_ensemble(
history=hist_n,
background=background,
invariants=invariants,
num_samples=1,
)[:, 0]
members.append(pred_n)
preds = torch.stack(members, dim=1) # (B, N, C, H, W)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Redundant wrapping of single model forward pass through _step_ensemble. The outer for n in range(num_samples) loop calls _step_ensemble(..., num_samples=1) for each member, which internally runs its own for _ in range(1) loop. This double-loops and makes the code harder to follow; calling the model directly (as the validation loop in _run_validation_metrics does) would be cleaner and avoids the vestigial num_samples=1 sentinel.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in the latest push. _step_ensemble was indeed dead code — _loss and _validation_loss both inline the ensemble loop directly. The double-loop pattern was from an earlier version.

kashif added 3 commits May 21, 2026 12:01
Remove cluster-specific slurm scripts (local paths), untrack FGN.md
(dev notes), add .gitignore, and fix README references.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…e indirection

- Fix inference _rollout: torch.cat([history[:, 1:], next_frame.unsqueeze(1)])
  so history window slides correctly for any history_frames value, not just 2
- Remove unimplemented amp config field from TrainingConfig and default.yaml
- Inline model call in _loss AR loop instead of routing through _step_ensemble
  with num_samples=1 (each member needs its own history, so the single-call
  collapse doesn't apply; direct call is cleaner)

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Run ruff format + fix across all fgn/ Python files
- Remove unused imports (Sequence, Callable, ShardTensor, math, torch)
- Replace assert with if/raise (S101), fix import order (I001),
  simplify loops to list-comprehension/extend (PERF401/102)
- Add noqa: E402 on intentional post-path-insert imports in stage4
- Upgrade FGNUNet docstring to MOD-003 (r-string, NumPy sections,
  Parameters/Forward/Outputs with LaTeX shapes, Examples)
- Add CHANGELOG.md entry under [2.1.0a0]

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
kashif added 14 commits May 21, 2026 12:18
…nfig

- utils/metrics.py: add energy_score_per_lead — fair energy score
  (multivariate CRPS) over the variable axis with spatial subsampling;
  new in earth2studio 0.13.0, captures cross-channel calibration
- utils/trainer.py: wire energy_score_per_lead into validation hook,
  save to metrics.npz and plot energy_score_vs_lead.png
- config/fgn.yaml: base Hydra config required by train.py
  (@hydra.main config_name="fgn") with model defaults and dataset
  skeleton; was missing, causing Hydra to error without all overrides
- config/fgn_arco.yaml: practical single-GPU ARCO ERA5 training config
  (2018–2022 train / 2023 val, hidden_channels=64, 5000 steps, full
  loss weights) for runs beyond the smoke-test default

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- datasets/__init__.py: auto-discovery registry (mirrors stormcast)
  that populates dataset_classes dict by scanning all FGNDataset
  subclasses; fixes ImportError since the regular HF `datasets`
  package beat the namespace package without __init__.py
- datasets/dataset.py: FGNDataset ABC (state_channels, background_channels,
  image_shape, get_invariants, output_only_channels) + worker_init;
  mirrors stormcast/datasets/dataset.py convention
- utils/loss.py: fair_crps (paper eq. 4), ensemble_mean_mse,
  build_channel_weights (§2.2.3 GraphCast scheme with z halved),
  build_area_weights (cos-lat normalised to unit mean)

All three files existed locally before the branch cleanup but were
never committed; this adds them properly.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Use _make_train_iter() to route through sharded_data_iter when domain
  parallelism is active (mirrors StormCast trainer pattern); plain DDP
  path gets an infinite-restart iterator instead of the old bare iter()
- Wrap both model forward sites in torch.autocast(bfloat16) and call
  .float() on preds to keep loss computation in fp32; halves activation
  memory at full 721x1440 resolution on H100 80GB
- train_fgn.sh: batch_size=1, domain_parallel_size=1 (DDP), run_id
  Hydra string quoting fix, PYTORCH_CUDA_ALLOC_CONF=expandable_segments

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
With domain_parallel_size=1 and 2 GPUs, data_parallel_size=2 so
local_batch = batch_size // 2; batch_size=1 → local_batch=0 causing
BatchSampler ValueError. Use batch_size=2 (global) = 1 per GPU.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Mark tp06, multi-rank sanity, AR stage scheduler, bad-seed detector
as done. Add status for currently running 5000-step job 99807.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Replaces the stale MVP-scaffold README with a full recipe README
modelled on stormcast/README.md: problem overview, dataset (ARCO),
getting started, configuration table, training (single-GPU / torchrun /
SLURM), AR fine-tuning schedule, inference, custom dataset interface,
memory guidance, and references.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Link to developers.google.com/weathernext/guides/models and
  model-specs-vmg from the intro and References section
- Clarify production deployment: 64 members (4 seeds × 16 each)
- Note u100m/v100m omission: ERA5/ARCO lacks 100m winds

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…elper

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
kashif added 21 commits May 29, 2026 21:06
Replace _ar_rollout (returns full (B,K,M,C,H,W) = ~51 GB) with
_ar_rollout_steps generator (yields (B,M,C,H,W) per step, ~2.5 GB).
All metrics computed per-step via unsqueeze(1) trick.
Mirrors earth2studio GenCast/GraphCast yield+del pattern.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- lr: 3e-4 → 8e-5 (paper stage 2-4 value)
- weight_decay: 1e-4 → 0.1 (paper value)
- Add linear warmup (1000 steps) + cosine decay LR schedule
- Save/restore scheduler state in checkpoints
- Log lr in progress line

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Paper stage 4: 8e-5 (1AR) → 8e-6 (2AR) → 8e-7 (3-8AR), not the
incorrect 8e-5/8e-5/8e-5/8e-6/8e-6/8e-7/8e-7/8e-7 we had.
Also thread lr_warmup_steps (800/400/100) through build_stage_cfg.
DEV_STAGES updated to mirror paper's LR ratios.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
_validation_loss() now returns (scalar, per_channel_mse[C]) alongside
the aggregate loss. Per-channel values are logged at each validation
step, mirroring StormCast's log_value(f'loss/valid/{ch}') convention.
Cheap: uses a single deterministic forward pass (latent=0), no ensemble.
Scheduler state saved/restored via save_checkpoint/load_checkpoint.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…alues

- training.torch_compile flag: compiles model before FSDP wrapping,
  skipped when ShardTensor active (matches CorrDiff/GraphCast pattern)
- Unwrap _orig_mod for save_checkpoint (OptimizedModule has no __len__)
- default.yaml: lr 3e-4→8e-5, weight_decay 1e-4→0.1 (paper Table A.2),
  add lr_warmup_steps/lr_min/torch_compile fields

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Was overriding lr=3e-4 and missing weight_decay/warmup fields.
Now: lr=8e-5, weight_decay=0.1, lr_warmup_steps=1000 (Stage 3 values).

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Matches DLWP/FCN-AFNO/diagnostic pattern from physicsnemo.utils.logging.
- LaunchLogger.initialize() called at trainer startup (no-op by default)
- train loop uses LaunchLogger context for loss/lr minibatch logging
- val loop logs val_loss + per-channel MSE via LaunchLogger
- use_wandb/use_mlflow/wandb_project config flags (all off by default)
- Fix _run_validation_metrics plot_power_spectra call (new signature)

Enable W&B: training.use_wandb=true training.wandb_project=my-project

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- add training.amp bool (default true) mirroring graphcast convention;
  autocast now reads self.amp instead of hardcoded is_available()
- remove _step_ensemble which was dead code (never called)

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- latent_dim default 16→32 (paper §2.3: z∈N(0,1)^32)
- eval CRPS switches to biased estimator (paper §4.1):
  deep-ensemble violates the independence assumption of the fair variant
- applies to both eval.py (e2s_crps fair=False) and trainer
  _run_validation_metrics (kcrps biased=True)

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…notes

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@kashif kashif requested a review from RishikeshRanade as a code owner June 23, 2026 09:33
kashif added 8 commits June 23, 2026 09:46
Paper §4.1 / Figure 2g-h evaluates REV at the 99.99th percentile
(z ≈ 3.72); previous default stopped at p99.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Closes the Figure 4 gap from the paper gap analysis:
- tc_eval.py: AR rollout + earth2studio TCTrackerWuDuan per IC,
  IBTrACS ground-truth pairing, position error + track REV metrics
- metrics.py: plot_tc_position_error / plot_tc_track_rev (Fig 4a-b)
- arco.py: expose init_time in __getitem__ return dict
Requires: pip install cucim-cu12

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Richardson 2000: v_clim = max(0, μ-r), v_perf = μ*(1-r).
Prior code had them swapped, producing wrong REV denominator.

Also add frac_active >= 0.5 guard on position error per §4.3:
ensemble-mean position only counted when ≥50% of members
still forecast the cyclone.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
After each save, rank 0 deletes all but the 2 most recent
.mdlus and checkpoint.*.pt files to prevent disk exhaustion.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…=12)

Switch from a tiny 449K-param convolutional U-Net to physicsnemo's DiT.
Patchifying 721x1440 into 181x360=65k tokens before attention gives
16x memory reduction, enabling batch>1 and larger model capacity (~33M params).

z ~ N(0,I)^32 conditions all transformer layers via AdaLN-Zero
(passed as condition= to DiT), matching paper §2.3's global
conditional layer-norm exactly.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Remove unused torch.nn import from nn.py (F401)
- Replace append loop with list comprehension in eval.py (PERF401)
- Add patch_size/hidden_size/depth/num_heads to ModelConfig (extra=forbid)
- Remove UNet-only fields (hidden_channels, group_norm_groups)

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
PatchEmbed2D allocates pos_embed with floor(H/ps) but pads the input
at runtime, producing ceil(H/ps) tokens — causing a size mismatch on
any non-divisible grid (ERA5 721 with ps=4: 180 slots vs 181 tokens).

Pre-pad the input to the nearest patch multiple before DiT so its
internal padding path never fires, matching StormCast's practice of
always passing divisible resolutions. Crop the output back afterward.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant